// file: kernel/os/memspace.c
// author: jiangxinpeng
// date: 2021.3.4
// update: 2021.8.9
// copyright: (C) jiangxinpeng. All rights reserved.

#include <os/memspace.h>
#include <os/vmm.h>
#include <os/debug.h>
#include <os/memcache.h>
#include <os/schedule.h>
#include <arch/page.h>
#include <lib/list.h>

// Print every memory space owned by the given vmm: start, end, and size.
void DumpMemSpace(vmm_t *vmm)
{
    mem_space_t *space = NULL;
    list_traversal_all_owner_to_next(space, &vmm->mem_space_head, list)
    {
        uint32_t size = (uint32_t)(space->end - space->start);
        KPrint("[mem space] start %x end %x size %x\n", (uint32_t)space->start, (uint32_t)space->end, size);
    }
}

// Insert a memory space into the vmm's address-sorted space list and merge
// it with adjacent spaces when they are contiguous and share the same
// protect/flags attributes. Shared mappings are never merged.
void MemSpaceInsert(vmm_t *vmm, mem_space_t *space)
{
    mem_space_t *pos = NULL;
    mem_space_t *pre = NULL;

    if (vmm == NULL || space == NULL)
        return;

    // find the first space that starts at or after the new space's end;
    // the new space is inserted right before it (when the loop completes
    // without a break, pos designates the list head and we append at tail)
    list_traversal_all_owner_to_next(pos, &vmm->mem_space_head, list)
    {
        if (space->end <= pos->start)
            break;
    }

    // the predecessor of the insert position — valid for both the break
    // case and the append-at-tail case. (The original captured pre inside
    // the loop, which pointed one node too far back when appending at the
    // tail, so the predecessor merge was silently skipped in that case.)
    pre = list_pre_owner(pos, list);

    // insert and link into the list
    list_add_before(&space->list, &pos->list);
    // record the owning vmm
    space->vmm = vmm;

    // shared mem areas must keep their own space; never merge them
    if (space->flags & MEM_SPACE_MAP_SHARE)
        return;

    // merge with the predecessor when contiguous and attribute-compatible
    if (!list_is_head(&pre->list) && pre->end == space->start &&
        pre->protect == space->protect && pre->flags == space->flags)
    {
        pre->end = space->end;
        list_del(&space->list);
        MemSpaceFree(space);
        // continue merging from the combined space
        space = pre;
    }

    // merge with the successor when contiguous and attribute-compatible
    if (!list_is_head(&pos->list) && pos->start == space->end &&
        pos->protect == space->protect && pos->flags == space->flags)
    {
        space->end = pos->end;
        list_del(&pos->list);
        MemSpaceFree(pos);
    }
}

// Search the vmm map area for an unmapped hole of at least len bytes.
// Returns the start address of the hole, or (uint64_t)-1 on failure.
// NOTE(review): if map_start itself is unmapped, later spaces are not
// checked for overlap with [map_start, map_start+len) — confirm intended.
uint64_t MemSpaceGetUnMap(vmm_t *vmm, uint32_t len)
{
    uint64_t map_addr = vmm->map_start;

    // locate the space (if any) that covers the initial candidate address
    mem_space_t *space = MemSpaceFind(vmm, map_addr);
    // walk the space list until the candidate no longer collides.
    // (The original tested list_is_head(...) != 0 — true only for the list
    // head itself — so the loop body could never run.)
    while (space != NULL && !list_is_head(&space->list))
    {
        // requested size above limit, or candidate runs past the map area
        if (len > USER_VMM_SIZE || map_addr + len > vmm->map_end)
            return (uint64_t)-1;
        // the hole before this space is big enough: use it
        if (map_addr + len < space->start)
            return map_addr;
        // move the candidate past this space and examine the next one.
        // (The original ADDED space->end instead of assigning it, and
        // discarded the result of list_next_owner, never advancing.)
        map_addr = space->end;
        space = list_next_owner(space, list);
    }
    return map_addr;
}

// find memspace by address
// Find the memory space containing addr, or NULL if no space contains it.
mem_space_t *MemSpaceFind(vmm_t *vmm, uint64_t addr)
{
    mem_space_t *space = NULL;
    list_traversal_all_owner_to_next(space, &vmm->mem_space_head, list)
    {
        // spaces are half-open [start, end): the merge logic treats
        // pre->end == space->start as adjacent, so end is exclusive.
        // (The original used addr <= space->end, which made a boundary
        // address resolve to the earlier of two adjacent spaces.)
        if (addr >= space->start && addr < space->end)
        {
            return space;
        }
    }
    return NULL;
}

// map a vaddr to pyaddr in memspace
void *MemSpaceMap(vmm_t *vmm, uint64_t addr, uint64_t pyaddr, uint64_t len, uint64_t protect, uint64_t flags)
{
    mem_space_t *space;

    if (vmm != NULL)
    {
        // len aligned to page size
        len = PageAlign(len);
        if (len != 0)
        {
            if (len > USER_VMM_SIZE || addr + len > addr + USER_VMM_SIZE)
            {
                KPrint("[memspace] len %d above user vmm size %d\n", (uint32_t)len, (uint32_t)USER_VMM_SIZE);
                return NULL;
            }
            // fix area map
            if (flags & MEM_SPACE_MAP_FIXED)
            {
                if (addr & PAGEALIGN_MASK)
                {
                    // if space had map before
                    space = MemSpaceFind(vmm, addr);
                    if (space)
                    {
                        KPrint("[memspace] vmm space had been map!\n");
                        return NULL;
                    }
                }
            }
            else
            {
                // map no fixed area
                // try a unmap area
                addr = MemSpaceGetUnMap(vmm, len);
                if (addr < 0)
                {
                    KPrint("[memspace] get unmap space failed!\n");
                    return NULL;
                }
            }

            // remap area
            if (flags & MEM_SPACE_MAP_REMAP)
            {
                protect |= PROTE_REMAP;
            }
            // try alloc new space
            space = MemSpaceAlloc();
            if (space != NULL)
            {
                // according function start argument set new space
                MemSpaceInit(space, addr, addr + len, flags, protect);
                MemSpaceInsert(vmm, space);
            }

            if (flags & MEM_SPACE_MAP_SHARE)
            {
                KPrint("map share page\n");
                // map share mem area
                _MapFixedPage(addr, pyaddr, len, protect);
            }
            else
            {
                KPrint("map no share page\n");
                // no share area
                _MapPage(addr, len, protect);
            }
            return PTYPE(addr);
        }
    }
    return NULL;
}

// Unmap [addr, addr+len) from the vmm. The containing space is split (or
// shrunk) around the hole; empty remnants are freed and contiguous,
// attribute-compatible neighbors are re-merged.
// Returns 0 on success, -1 on failure.
int MemSpaceUnmap(vmm_t *vmm, uint64_t addr, uint64_t len)
{
    mem_space_t *space;
    mem_space_t *pre;
    mem_space_t *new;

    if (vmm == NULL)
        return -1;

    // addr must be page aligned. (The original tested PageAlign(addr) as a
    // boolean, which only rejected addresses that align to zero, not
    // unaligned ones.)
    if ((addr & PAGEALIGN_MASK) != 0 || len >= USER_VMM_SIZE || addr >= vmm->map_end)
        return -1;

    len = PageAlign(len);
    if (len == 0)
        return -1;

    // find the containing space BEFORE computing its predecessor (the
    // original called MemSpaceFindPre first, which dereferences a NULL
    // owner when addr is unmapped)
    space = MemSpaceFind(vmm, addr);
    if (!space)
    {
        KPrint(PRINT_ERR "%s: no found the space!\n", __func__);
        return -1;
    }
    pre = list_pre_owner(space, list);

    // check addr range is fully inside the space
    if (addr < space->start || addr + len > space->end)
    {
        return -1;
    }

    // unmap pages: only clears the page-table entries, physical pages are
    // kept for shared mappings per the flag
    _UnMapPage(addr, len, space->flags & MEM_SPACE_MAP_SHARE);

    // split the space around the freed hole: [start,addr) stays in space,
    // [addr+len,end) moves into a new space
    new = MemSpaceAlloc();
    if (new == NULL)
        return -1;
    new->start = addr + len;
    new->end = space->end;
    // the split-off tail must inherit the original space's attributes
    // (the original left flags/protect/vmm uninitialized on new)
    new->flags = space->flags;
    new->protect = space->protect;
    new->vmm = space->vmm;
    // shrink the old space down to the head part
    space->end = addr;
    list_add_after(&new->list, &space->list);

    // head part became empty: free it and fall back to the predecessor
    if (space->start == space->end)
    {
        list_del(&space->list);
        MemSpaceFree(space);
        space = pre;
    }
    // tail part became empty: free it, nothing left to merge
    if (new->start == new->end)
    {
        list_del(&new->list);
        MemSpaceFree(new);
        return 0;
    }
    // re-merge neighbor and tail when contiguous and compatible; guard
    // against space now designating the list head (the original merged
    // unconditionally and could dereference the head sentinel)
    if (space != NULL && !list_is_head(&space->list) && space->end == new->start &&
        space->protect == new->protect && space->flags == new->flags)
    {
        space->end = new->end;
        list_del(&new->list);
        MemSpaceFree(new);
    }
    return 0;
}

// Map len bytes at user address addr; shared mappings are backed page by
// page by the kernel virtual range starting at vaddr. Returns the mapped
// user address (PTYPE) or NULL on failure.
void *MemSpaceMapViraddr(vmm_t *vmm, uint32_t addr, uint32_t vaddr, uint32_t len, uint32_t prot, uint32_t flags)
{
    mem_space_t *space, *new;

    if (!vmm)
        return NULL;

    len = PageAlign(len);
    if (!len)
    {
        KPrint(PRINT_ERR "%s: len is zero,please make sure page aligned!\n", __func__);
        return NULL;
    }
    // check limits: size, top bound, overflow, bottom bound
    if (len > USER_VMM_SIZE || addr > USER_VMM_TOP_ADDR || addr > USER_VMM_TOP_ADDR - len || addr < USER_VMM_BASE_ADDR)
    {
        // the original format string had %s for __func__ but never passed it,
        // shifting every following argument
        KPrint(PRINT_ERR "%s: addr:%x len:%d invalid!\n", __func__, addr, len);
        return NULL;
    }

    // fixed map
    if (flags & MEM_SPACE_MAP_FIXED)
    {
        if (addr & PAGEALIGN_MASK)
            return NULL;
        // refuse to map over an existing space. (The original also read
        // space->start when space was NULL, crashing on the common path.)
        space = MemSpaceFind(vmm, addr);
        if (space != NULL)
        {
            KPrint(PRINT_ERR "%s: the fixed space had exist! addr:%x\n", __func__, addr);
            return NULL;
        }
    }
    else
    {
        // alloc a new addr from the unmapped area
        addr = MemSpaceGetUnMap(vmm, len);
        // failure sentinel; addr is unsigned so the original addr < 0
        // test was always false
        if (addr == (uint32_t)-1)
        {
            KPrint("%s: get unmap space failed!\n", __func__);
            return NULL;
        }
    }

    if (flags & MEM_SPACE_MAP_REMAP)
    {
        prot |= PROTE_REMAP;
    }

    // create and register the new space
    new = MemSpaceAlloc();
    if (!new)
    {
        KPrint(PRINT_ERR "%s: alloc space failed!\n", __func__);
        return NULL;
    }
    MemSpaceInit(new, addr, addr + len, prot, flags);
    MemSpaceInsert(vmm, new);

    // share map: mirror each kernel page at a fixed physical address
    if (flags & MEM_SPACE_MAP_SHARE)
    {
        uint64_t vend = vaddr + len;
        uint64_t vstart = addr;
        uint64_t paddr;

        while (vaddr < vend)
        {
            paddr = ADDR(KERNEL_VBASE2PYBASE(vaddr));
            _MapFixedPage(vstart, paddr, PAGE_SIZE, prot); // map fixed page
            vaddr += PAGE_SIZE;
            vstart += PAGE_SIZE;
        }
    }
    else
    {
        _MapPage(addr, len, prot); // map new pages
    }
    // the original fell off the end of this non-void function (undefined
    // behavior); return the mapped address like MemSpaceMap does
    return PTYPE(addr);
}

// Grow the heap by len bytes starting at addr: extend the adjacent heap
// space when one ends exactly at addr, otherwise allocate and insert a new
// heap space. Returns addr on success, (uint64_t)-1 on allocation failure.
static uint64_t MemSpaceExpandHeap(vmm_t *vmm, uint64_t addr, uint64_t len)
{
    mem_space_t *space = NULL;
    uint64_t flags = MEM_SPACE_MAP_HEAP;
    uint64_t protect = PROTE_USER | PROTE_WRITE | PROTE_EXEC;

    len = PageAlign(len);
    if (!len)
        return addr;

    if (addr)
    {
        // try to extend the heap space that ends exactly at addr
        space = MemSpaceFind(vmm, addr - 1);
        if (space && space->end == addr && space->flags == flags)
        {
            space->end = addr + len;
            // map the extension's pages. (The original skipped the page
            // mapping on this path while the fresh-space path below mapped
            // its pages — NOTE(review): confirm no demand-paging handler
            // was meant to cover this.)
            _MapPage(addr, len, protect);
            return addr;
        }

        // no adjacent heap space: alloc a new space and insert it
        space = MemSpaceAlloc();
        if (!space)
            return (uint64_t)-1;
        MemSpaceInit(space, addr, addr + len, protect, flags);
        MemSpaceInsert(vmm, space);
        _MapPage(addr, len, protect);
    }
    return addr;
}

// find a memspace to front
mem_space_t *MemSpaceFindPre(vmm_t *vmm, uint64_t addr)
{
    mem_space_t *pos = MemSpaceFind(vmm, addr);

    return list_pre_owner(pos, list);
}

// brk-style syscall: set the current task's heap end to `heap`.
// heap == 0 queries the heap start. Shrinks unmap the freed pages; grows
// go through MemSpaceExpandHeap. Returns the resulting heap end.
uint32_t SysMemSpaceExpandHeap(uint32_t heap)
{
    uint64_t old_heap, new_heap;
    vmm_t *vmm = cur_task->vmm;

    // heap == 0: report the heap start address
    if (heap == 0)
    {
        KPrint("[vmm] heap start at %x\n", vmm->heap_start);
        return vmm->heap_start;
    }

    // requested end below the heap base is invalid
    if (heap < vmm->heap_start)
        goto end;

    new_heap = PageAlign(heap);
    old_heap = PageAlign(vmm->heap_end);
    // same page: only the byte-granular end moves
    if (new_heap == old_heap)
        goto set_heap;

    // new end inside the current heap range: shrink by unmapping the tail
    if (heap <= vmm->heap_end && heap >= vmm->heap_start)
    {
        if (!MemSpaceUnmap(vmm, new_heap, old_heap - new_heap)) // free spare pages
            goto set_heap;
        KPrint(PRINT_ERR "%s: unmap memspace failed!\n", __func__);
        goto end;
    }

    // grow: reject requests beyond the heap size limit
    if (heap > vmm->heap_start + MEM_SPACE_HEAP_SIZE_MAX)
    {
        KPrint(PRINT_ERR "%s: out of head bound!\n", __func__);
        goto end;
    }

    mem_space_t *find = NULL;
    // refuse to grow into an existing mapping
    if ((find = MemSpaceFindIntersection(vmm, old_heap, new_heap + PAGE_SIZE)) != NULL)
    {
        // the original format string had %s but passed no argument
        KPrint("%s: space intersection\n", __func__);
        goto end;
    }

    // expand the heap space and map the new pages
    if (MemSpaceExpandHeap(vmm, old_heap, new_heap - old_heap) != old_heap)
    {
        goto end;
    }
set_heap:
    vmm->heap_end = heap; // record the new heap end
end:
    return vmm->heap_end;
}

// Simple self-test: expand the heap by one page, fill it with 0xff, and
// verify every byte reads back.
void MemSpaceTest()
{
    uint32_t init_heap = SysMemSpaceExpandHeap(0);
    KPrint("heap start %x\n", init_heap);
    uint32_t new_heap1 = SysMemSpaceExpandHeap(init_heap + 4096);
    if (new_heap1)
    {
        KPrint("new heap1 extend %x\n", new_heap1);
        // the original passed the integer address directly to memset;
        // convert through uintptr_t to form a valid pointer
        memset((void *)(uintptr_t)init_heap, 0xff, 4096);
        KPrint("write new heap1 ok!\n");
    }

    // verify the fill pattern byte by byte
    for (int i = 0; i < 4096; i++)
        if (*(uint8_t *)(uintptr_t)(init_heap + i) != 0xff)
            KPrint("test mem err!\n");
    KPrint("test ok!\n");

    KPrint("heap addr %x to %x\n", new_heap1, Vbase2Pybase(init_heap));
}